function PlotTransitionPath(parms,gesol,prob_Nxz)
% PlotTransitionPath  Plot transition paths after a switch of the aggregate
% state (written for Nz = 2 only).
%
% For each switch direction (2->1, then 1->2) and three starting points for
% (N,x), simulate the path with GetPath. The starting points are the means,
% 5th and 95th percentiles of the marginals of N and x taken from
% prob_Nxz(:,:,z_old). The function then plots:
%   - one figure of levels per direction, then
%   - one figure of changes relative to t=-1 per direction.
%
% timing:
% t=-1: old state
% t=0: state switches
% t>0: state remains unchanged
%
% prob_Nxz : array indexed (N, x, z); presumably the stationary distribution
%            over the state space -- confirm against caller.

    if parms.Nz~=2
        error('PlotTransitionPath::code is designed for Nz=2.');
    end

    Tmax = 60;
    ts = -1:Tmax;

    NumScenario = 3;
    dims = [numel(ts), 2, NumScenario]; % (time, post-switch state, scenario)
    Nt = zeros(dims);
    xt = zeros(dims);
    thetat = zeros(dims);
    y1t = zeros(dims);
    y5t = zeros(dims);
    entropy1t = zeros(dims);
    entropy5t = zeros(dims);
    Edm1t = zeros(dims);
    Edm5t = zeros(dims);

    phi_of_x = @(x)1./(1 + exp(-x));
    logzeta_fun = @(N,phi) log(N + (1-N).*(phi.^(1-1/parms.chi))) ...
        + (parms.gamma-1)*(phi.*(1-N) + N);

    for IDX_Z = 1:2
        % state switches IDX_Z_INIT -> IDX_Z (2->1, then 1->2); valid because Nz==2
        IDX_Z_INIT = 3 - IDX_Z;

        % marginal distributions of N and x in the pre-switch state
        prob_N = sum(prob_Nxz(:,:,IDX_Z_INIT),2);
        prob_x = sum(prob_Nxz(:,:,IDX_Z_INIT),1);
        prob_N = prob_N(:)/sum(prob_N(:));
        prob_x = prob_x(:)/sum(prob_x(:));
        cdf_N = cumsum(prob_N);
        cdf_x = cumsum(prob_x);

        for idx_scenario = 1:NumScenario
            switch idx_scenario
                case 1 % start at conditional means
                    N0 = sum(prob_N(:).*parms.grid_N(:));
                    x0 = sum(prob_x(:).*parms.grid_x(:));
                case 2 % start at P5 (low N, low phi)
                    N0 = InterpPctile(parms.grid_N,cdf_N,0.05);
                    x0 = InterpPctile(parms.grid_x,cdf_x,0.05);
                case 3 % start at P95 (high N, high phi)
                    N0 = InterpPctile(parms.grid_N,cdf_N,0.95);
                    x0 = InterpPctile(parms.grid_x,cdf_x,0.95);
            end
            [ts, Nt(:,IDX_Z,idx_scenario), xt(:,IDX_Z,idx_scenario), ...
                thetat(:,IDX_Z,idx_scenario), y1t(:,IDX_Z,idx_scenario), y5t(:,IDX_Z,idx_scenario),...
                entropy1t(:,IDX_Z,idx_scenario),entropy5t(:,IDX_Z,idx_scenario),...
                Edm1t(:,IDX_Z,idx_scenario),Edm5t(:,IDX_Z,idx_scenario)] = GetPath(parms,gesol,N0,x0,IDX_Z_INIT,IDX_Z,Tmax);
        end
    end

    Ut = 1 - Nt;
    phit = phi_of_x(xt);
    logzetat = logzeta_fun(Nt,phit);
    slopet = y5t - y1t;
    slopet_entropy = entropy5t - entropy1t;
    slopet_Edm = Edm5t - Edm1t;

    plot_vars = {Ut, phit, logzetat, thetat, y1t, entropy1t, Edm1t, ...
        y5t, entropy5t, Edm5t, slopet, slopet_entropy, slopet_Edm};
    var_names = {'U(t)','\phi(t)','log \zeta(t)','\theta(t)','y_t^{(1)}','y_t^{(1)}, entropy',...
        'y_t^{(1)}, other','y_t^{(5)}','y_t^{(5)}, entropy','y_t^{(5)}, other',...
        'slope(t)','slope(t), entropy','slope(t), other'};

    % levels: one figure per switch direction
    for IDX_Z = 1:2
        PlotPathPanels(ts,plot_vars,var_names,IDX_Z,false);
    end

    % changes relative to t=-1: one figure per switch direction
    for IDX_Z = 1:2
        PlotPathPanels(ts,plot_vars,var_names,IDX_Z,true);
    end

end

function PlotPathPanels(ts,plot_vars,var_names,IDX_Z,as_change)
% PlotPathPanels  Open a new figure with one subplot per entry of plot_vars.
% Each subplot shows the three starting-point scenarios for the switch into
% state IDX_Z (from state 3-IDX_Z; assumes Nz == 2).
%
%   plot_vars : cell of arrays indexed (time, post-switch state, scenario)
%   var_names : y-axis labels, one per entry of plot_vars
%   as_change : if true, plot each series minus its own t=-1 value (first
%               row) and prefix the y-label with '\Delta'

    IDX_Z_INIT = 3 - IDX_Z;
    m = 3; n = 5;
    LineWidth = 1.5;
    line_styles = {'-','--',':'};

    figure

    for idxplot = 1:numel(plot_vars)

        subplot(m,n,idxplot); hold on;
        for idx_scenario = 1:numel(line_styles)
            series = plot_vars{idxplot}(:,IDX_Z,idx_scenario);
            if as_change
                series = series - series(1);
            end
            plot(ts, series, line_styles{idx_scenario}, 'LineWidth', LineWidth)
        end
        xlabel('t')
        if as_change
            ylabel(['\Delta',var_names{idxplot}])
        else
            ylabel(var_names{idxplot})
        end
        title(['State switch: ',num2str(IDX_Z_INIT),'->',num2str(IDX_Z)])
        if idxplot==1
            legend({'start at cond means','start at P5','start at P95'})
            legend boxoff
        end

    end

end

function pct_x = InterpPctile(grid_x,cdf_x,prob_target)
% InterpPctile  Percentile of a discrete distribution on a grid, obtained by
% linearly interpolating its CDF.
%
%   grid_x      : grid points, same length as cdf_x (assumed increasing)
%   cdf_x       : cumulative probability at each grid point (nondecreasing)
%   prob_target : target probability, e.g. 0.05 for the 5th percentile
%   pct_x       : value where the interpolated CDF reaches prob_target
%
% A target below cdf_x(1) (the first grid point already carries more mass
% than prob_target) maps to grid_x(1). A target above cdf_x(end) maps to
% grid_x(end). Previously both cases failed with an empty index.

    if prob_target < cdf_x(1)
        pct_x = grid_x(1);
        return
    end

    idx_x_Pval = find(cdf_x(1:end-1)<=prob_target & prob_target<=cdf_x(2:end),1,'first');

    if isempty(idx_x_Pval)
        % target beyond the last CDF value (e.g. CDF not summing to 1)
        pct_x = grid_x(end);
    elseif cdf_x(idx_x_Pval)>=cdf_x(idx_x_Pval+1)
        % flat CDF segment equal to the target: take the midpoint
        pct_x = (grid_x(idx_x_Pval) + grid_x(idx_x_Pval+1))/2;
    else
        pct_x = interp1([cdf_x(idx_x_Pval) cdf_x(idx_x_Pval+1)],...
            [grid_x(idx_x_Pval) grid_x(idx_x_Pval+1)],prob_target);
    end

end

function [ts, N_path, x_path, theta_path, y1_path, y5_path, ...
    entropy1_path,entropy5_path,Edm1_path,Edm5_path] = GetPath(parms,gesol,N0,x0,idx_zold,idx_znew,T)
% GetPath  Simulate (N,x) after the aggregate state switches from idx_zold
% to idx_znew at t=0 and then stays at idx_znew. Also interpolate theta and
% bond-yield quantities along the path.
%
% Timing: the shock happens at t=0; t=-1 is really t=0-epsilon (old state,
% initial point (N0,x0)). (N,x) at t=0 equal (N0,x0) and evolve from t=1 on.
%
% Outputs are row vectors of length T+2, aligned with ts = [-1, 0:T]:
%   N_path, x_path   : simulated states (clamped to the grids)
%   theta_path       : gesol.theta along the path
%   y1_path, y5_path : -log(price) of maturities 12 and 60, divided by
%                      1 and 5 (presumably months -> annualized; confirm)
%   entropy1/5_path  : -CondEntropy, scaled the same way
%   Edm1/5_path      : -E_dlogM, scaled the same way
% The t=-1 entries use state idx_zold; all later entries use idx_znew.

    % fail early, with a clear message, if the needed maturities are absent
    idx_y1 = find(gesol.real_bonds.maturities==12);
    idx_y5 = find(gesol.real_bonds.maturities==60);
    if numel(idx_y1)~=1 || numel(idx_y5)~=1
        error('GetPath::gesol.real_bonds.maturities must contain 12 and 60 exactly once.');
    end

    ts = 0:T;
    idx_zpath = repmat(idx_znew,size(ts)); % z stays at idx_znew from t=0 on
    N_path = zeros(size(ts));
    x_path = zeros(size(ts));
    N_path(1) = N0;
    x_path(1) = x0;

    [NN_mesh,xx_mesh] = meshgrid(parms.grid_N(:),parms.grid_x(:));

    % conditional mean/std of next-period z given each current state;
    % max(.,0) keeps rounding from producing a negative variance
    expected_z = parms.StateTransitionProbs*parms.grid_z(:);
    std_z = sqrt(max(parms.StateTransitionProbs*(parms.grid_z(:).^2) - expected_z.^2, 0));

    for t = 2:numel(ts)
        iz_prev = idx_zpath(t-1);
        N_prev = N_path(t-1);

        N_path(t) = (1 - parms.s(iz_prev))*N_prev ...
            + interp2(NN_mesh,xx_mesh,gesol.f(:,:,iz_prev)',N_prev,x_path(t-1))*(1 - N_prev);

        % AR(1) in x plus a loading times the standardized z innovation; the
        % innovation is zero when next-period z is deterministic (std 0)
        x_path(t) = (1-parms.rho_x)*parms.x_bar + parms.rho_x*x_path(t-1);
        if std_z(iz_prev) > 0
            sigma_x_prev = interp1(parms.grid_N(:),parms.sigma_x(:,iz_prev),N_prev);
            x_path(t) = x_path(t) ...
                + sigma_x_prev*(parms.grid_z(idx_zpath(t)) - expected_z(iz_prev))/std_z(iz_prev);
        end

        % keep the path on the solution grids (NaN is left untouched)
        if N_path(t)>parms.grid_N(end); N_path(t)=parms.grid_N(end); end
        if N_path(t)<parms.grid_N(1); N_path(t)=parms.grid_N(1); end
        if x_path(t)>parms.grid_x(end); x_path(t)=parms.grid_x(end); end
        if x_path(t)<parms.grid_x(1); x_path(t)=parms.grid_x(1); end
    end

    % add initial conditions at t=0-epsilon
    ts = [-1,ts];
    N_path = [N0,N_path];
    x_path = [x0,x_path];

    along_path = @(vals_Nxz) InterpAlongPath(NN_mesh,xx_mesh,vals_Nxz, ...
        idx_zold,idx_znew,N_path,x_path);

    theta_path = along_path(gesol.theta);
    y1_path = along_path(-log(gesol.real_bonds.prices(:,:,:,idx_y1)));
    y5_path = along_path(-log(gesol.real_bonds.prices(:,:,:,idx_y5))/5);
    entropy1_path = along_path(-gesol.entropy.CondEntropy(:,:,:,idx_y1));
    entropy5_path = along_path(-gesol.entropy.CondEntropy(:,:,:,idx_y5)/5);
    Edm1_path = along_path(-gesol.entropy.E_dlogM(:,:,:,idx_y1));
    Edm5_path = along_path(-gesol.entropy.E_dlogM(:,:,:,idx_y5)/5);

end

function vals_path = InterpAlongPath(NN_mesh,xx_mesh,vals_Nxz,idx_zold,idx_znew,N_path,x_path)
% InterpAlongPath  Interpolate vals_Nxz (indexed (N, x, z)) at the points
% (N_path, x_path). The first point (t=-1) uses state idx_zold; the rest
% use idx_znew. Returns an array the size of N_path.

    vals_path = zeros(size(N_path));
    vals_path(1) = interp2(NN_mesh,xx_mesh,vals_Nxz(:,:,idx_zold)',N_path(1),x_path(1));
    vals_path(2:end) = interp2(NN_mesh,xx_mesh,vals_Nxz(:,:,idx_znew)',N_path(2:end),x_path(2:end));

end

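%% old: superseded implementation of PlotTransitionPath, kept for reference only (commented out, never executed)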

% function PlotTransitionPath(parms,gesol,prob_Nxz)
% % timing:
% % t=-1: old state
% % t=0: state switches
% % t>0: state remain unchanged
% 
%     if parms.Nz~=2
%         error('PlotTransitionPath::code is designed for Nz=2.');
%     end
%     
%     Tmax = 60;
%     ts = 0:Tmax;
%     
%     [NN_mesh,xx_mesh] = meshgrid(parms.grid_N(:),parms.grid_x(:));
%     
%     NumScenario = 3;
%     Nt = zeros(Tmax + 1, 2, NumScenario);
%     phi_of_x = @(x)1./(1 + exp(-x));
%     xt = zeros(Tmax + 1, 2, NumScenario);
%     logzetat = zeros(Tmax + 1, 2, NumScenario);
%     
%     logzeta_fun = @(N,phi) log(N + (1-N).*(phi.^(1-1/parms.chi))) ...
%         + (parms.gamma-1)*(phi.*(1-N) + N);
%     
%     for IDX_Z = 1:2
%         % state switches to IDX_Z
%         switch IDX_Z
%             case 1
%                 IDX_Z_INIT = 2; % 2->1
%             case 2
%                 IDX_Z_INIT = 1; % 1->2
%         end
%         prob_N = sum(prob_Nxz(:,:,IDX_Z_INIT),2); prob_x = sum(prob_Nxz(:,:,IDX_Z_INIT),1);
%         prob_N = prob_N(:)/sum(prob_N(:)); prob_x = prob_x(:)/sum(prob_x(:));
%         cdf_N = cumsum(prob_N); cdf_x = cumsum(prob_x);
%         expected_znext = parms.StateTransitionProbs(IDX_Z,:)*parms.grid_z(:);
%         znext = parms.grid_z(IDX_Z);
%         std_znext = sqrt(parms.StateTransitionProbs(IDX_Z,:)*(parms.grid_z(:).^2) - expected_znext^2);
%         shock_z = (znext - expected_znext)/std_znext;
%         for idx_scenario = 1:NumScenario
%             switch idx_scenario
%                 case 1
%                     % start at conditional means
%                     Nt(1,IDX_Z,idx_scenario) = sum(prob_N(:).*parms.grid_N(:));
%                     xt(1,IDX_Z,idx_scenario) = sum(prob_x(:).*parms.grid_x(:));
%                 case 2
%                     % start at P5 (low N, low phi)
%                     Pval = 0.05;
%                     idx_N_Pval = find(cdf_N(1:end-1)<=Pval & Pval<=cdf_N(2:end),1,'first');
%                     if cdf_N(idx_N_Pval)>=cdf_N(idx_N_Pval+1)
%                         Nt(1,IDX_Z,idx_scenario) = (parms.grid_N(idx_N_Pval) + parms.grid_N(idx_N_Pval+1))/2;
%                     else
%                         Nt(1,IDX_Z,idx_scenario) = interp1([cdf_N(idx_N_Pval);cdf_N(idx_N_Pval+1)],...
%                             [parms.grid_N(idx_N_Pval);parms.grid_N(idx_N_Pval+1)], Pval);
%                     end
%                     Pval = 0.05;
%                     idx_x_Pval = find(cdf_x(1:end-1)<=Pval & Pval<=cdf_x(2:end),1,'first');
%                     if cdf_x(idx_x_Pval)>=cdf_x(idx_x_Pval+1)
%                         xt(1,IDX_Z,idx_scenario) = (parms.grid_x(idx_x_Pval) + parms.grid_x(idx_x_Pval+1))/2;
%                     else
%                         xt(1,IDX_Z,idx_scenario) = interp1([cdf_x(idx_x_Pval);cdf_x(idx_x_Pval+1)],...
%                             [parms.grid_x(idx_x_Pval);parms.grid_x(idx_x_Pval+1)], Pval);
%                     end
%                 case 3
%                     % start at P95 (high N, high phi)
%                     Pval = 0.95;
%                     idx_N_Pval = find(cdf_N(1:end-1)<=Pval & Pval<=cdf_N(2:end),1,'first');
%                     if cdf_N(idx_N_Pval)>=cdf_N(idx_N_Pval+1)
%                         Nt(1,IDX_Z,idx_scenario) = (parms.grid_N(idx_N_Pval) + parms.grid_N(idx_N_Pval+1))/2;
%                     else
%                         Nt(1,IDX_Z,idx_scenario) = interp1([cdf_N(idx_N_Pval);cdf_N(idx_N_Pval+1)],...
%                             [parms.grid_N(idx_N_Pval);parms.grid_N(idx_N_Pval+1)], Pval);
%                     end
%                     Pval = 0.95;
%                     idx_x_Pval = find(cdf_x(1:end-1)<=Pval & Pval<=cdf_x(2:end),1,'first');
%                     if cdf_x(idx_x_Pval)>=cdf_x(idx_x_Pval+1)
%                         xt(1,IDX_Z,idx_scenario) = (parms.grid_lambda(idx_x_Pval) + parms.grid_lambda(idx_x_Pval+1))/2;
%                     else
%                         xt(1,IDX_Z,idx_scenario) = interp1([cdf_x(idx_x_Pval);cdf_x(idx_x_Pval+1)],...
%                             [parms.grid_x(idx_x_Pval);parms.grid_x(idx_x_Pval+1)], Pval);
%                     end
%             end
%             for t = 2:Tmax+1
%                 Nt(t,IDX_Z,idx_scenario) = (1-parms.s(IDX_Z))*Nt(t-1,IDX_Z,idx_scenario) + interp2(NN_mesh,xx_mesh,gesol.f(:,:,IDX_Z)',...
%                     Nt(t-1,IDX_Z,idx_scenario),xt(t-1,IDX_Z,idx_scenario))*(1 - Nt(t-1,IDX_Z,idx_scenario));
%                 xt(t,IDX_Z,idx_scenario) = (1-parms.rho_x)*parms.x_bar + parms.rho_x*xt(t-1,IDX_Z,idx_scenario) + interp1(parms.grid_N(:),parms.sigma_x(:,IDX_Z),...
%                     Nt(t-1,IDX_Z,idx_scenario))*shock_z;
%                 if Nt(t,IDX_Z,idx_scenario)<parms.grid_N(1)
%                     Nt(t,IDX_Z,idx_scenario)=parms.grid_N(1);
%                 elseif Nt(t,IDX_Z,idx_scenario)>parms.grid_N(end)
%                     Nt(t,IDX_Z,idx_scenario)=parms.grid_N(end);
%                 end
%                 if xt(t,IDX_Z,idx_scenario)<parms.grid_x(1)
%                     xt(t,IDX_Z,idx_scenario)=parms.grid_x(1);
%                 elseif xt(t,IDX_Z,idx_scenario)>parms.grid_x(end)
%                     xt(t,IDX_Z,idx_scenario)=parms.grid_x(end);
%                 end
%             end
%         end
%     end
%     
%     logzetat = logzeta_fun(Nt,phi_of_x(xt));
%     
%     % plot
%     m = 2; n = 3; idxplot = 0;
%     
%     LineWidth = 1.5;
%     
%     figure
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, Nt(:,1,1), '-', 'LineWidth', LineWidth)
%     plot(ts, Nt(:,1,2), '--', 'LineWidth', LineWidth)
%     plot(ts, Nt(:,1,3), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('N(t)')
%     title('State switch: 2->1')
%     legend({'start at cond means','start at P5','start at P95'})
%     legend boxoff
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, phi_of_x(xt(:,1,1)), '-', 'LineWidth', LineWidth)
%     plot(ts, phi_of_x(xt(:,1,2)), '--', 'LineWidth', LineWidth)
%     plot(ts, phi_of_x(xt(:,1,3)), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('\phi(t)')
%     title('State switch: 2->1')
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, logzetat(:,1,1), '-', 'LineWidth', LineWidth)
%     plot(ts, logzetat(:,1,2), '--', 'LineWidth', LineWidth)
%     plot(ts, logzetat(:,1,3), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('log \zeta(t)')
%     title('State switch: 2->1')
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, Nt(:,2,1), '-', 'LineWidth', LineWidth)
%     plot(ts, Nt(:,2,2), '--', 'LineWidth', LineWidth)
%     plot(ts, Nt(:,2,3), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('N(t)')
%     title('State switch: 1->2')
%     legend({'start at cond means','start at P5','start at P95'})
%     legend boxoff
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, phi_of_x(xt(:,2,1)), '-', 'LineWidth', LineWidth)
%     plot(ts, phi_of_x(xt(:,2,2)), '--', 'LineWidth', LineWidth)
%     plot(ts, phi_of_x(xt(:,2,3)), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('\phi(t)')
%     title('State switch: 1->2')
%     
%     idxplot = idxplot + 1; subplot(m,n,idxplot);
%     hold on
%     plot(ts, logzetat(:,2,1), '-', 'LineWidth', LineWidth)
%     plot(ts, logzetat(:,2,2), '--', 'LineWidth', LineWidth)
%     plot(ts, logzetat(:,2,3), ':', 'LineWidth', LineWidth)
%     xlabel('t')
%     ylabel('log \zeta(t)')
%     title('State switch: 1->2')
% 
% end